/*
 * Copyright (C) 2015 Edward Raff
 *
 * This program is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program.  If not, see <http://www.gnu.org/licenses/>.
 */
package com.jstarcraft.ai.jsat.datatransform.visualization;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.IdentityHashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Random;

import com.jstarcraft.ai.jsat.DataSet;
import com.jstarcraft.ai.jsat.classifiers.DataPoint;
import com.jstarcraft.ai.jsat.datatransform.DataTransform;
import com.jstarcraft.ai.jsat.datatransform.PCA;
import com.jstarcraft.ai.jsat.distributions.Normal;
import com.jstarcraft.ai.jsat.linear.DenseVector;
import com.jstarcraft.ai.jsat.linear.Vec;
import com.jstarcraft.ai.jsat.linear.VecPaired;
import com.jstarcraft.ai.jsat.linear.distancemetrics.DistanceMetric;
import com.jstarcraft.ai.jsat.linear.distancemetrics.EuclideanDistance;
import com.jstarcraft.ai.jsat.linear.vectorcollection.DefaultVectorCollection;
import com.jstarcraft.ai.jsat.linear.vectorcollection.VectorCollection;
import com.jstarcraft.ai.jsat.math.FastMath;
import com.jstarcraft.ai.jsat.math.optimization.stochastic.Adam;
import com.jstarcraft.ai.jsat.math.optimization.stochastic.GradientUpdater;
import com.jstarcraft.ai.jsat.math.rootfinding.Zeroin;
import com.jstarcraft.ai.jsat.utils.concurrent.AtomicDouble;
import com.jstarcraft.ai.jsat.utils.concurrent.ParallelUtils;
import com.jstarcraft.ai.jsat.utils.random.RandomUtil;

/**
 * t-distributed Stochastic Neighbor Embedding is an algorithm for creating low
 * dimensional embeddings of datasets, for the purpose of visualization. It
 * attempts to keep points that are near each other in the original space near
 * each other in the low dimensional space as well, with less emphasis on
 * maintaining far-away relationships in the data. This implementation uses
 * tree-based approximated gradients to learn the embedding in O(n log n)
 * time.<br>
 * <br>
 * If the input dataset has a dimension greater than 50, it is advisable to
 * project the data set down to 50 dimensions using {@link PCA} or some similar
 * technique.<br>
 * <br>
 * See:<br>
 * <ul>
 * <li>Maaten, L. Van Der, &amp; Hinton, G. (2008). <i>Visualizing Data using
 * t-SNE</i>. Journal of Machine Learning Research, 9, 2579–2605.</li>
 * <li>Van der Maaten, L. (2014). <i>Accelerating t-SNE using Tree-Based
 * Algorithms</i>. Journal of Machine Learning Research, 15, 3221–3245.
 * Retrieved from
 * <a href="http://jmlr.org/papers/v15/vandermaaten14a.html">here</a></li>
 * </ul>
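 * <br>
 * A minimal usage sketch, assuming a dataset {@code data} (e.g., a
 * {@code SimpleDataSet}) has already been loaded:
 * <pre>{@code
 * TSNE tsne = new TSNE();
 * tsne.setPerplexity(30);   // soft target for neighborhood size
 * tsne.setIterations(1000); // gradient descent steps
 * SimpleDataSet embedded = tsne.transform(data, true); // parallel 2D embedding
 * }</pre>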
 *
 * @author Edward Raff
 */
public class TSNE implements VisualizationTransform {
    private double alpha = 4;
    private double exaggeratedPortion = 0.25;
    private DistanceMetric dm = new EuclideanDistance();
    private int T = 1000;
    private double perplexity = 30;
    private double theta = 0.5;
    /**
     * The target embedding dimension, hard coded to 2 for now
     */
    private int s = 2;

    /**
     * &alpha; is the "early exaggeration" constant. It is a multiple applied to
     * part of the gradient for th first quarter of iterations, and can improve the
     * quality of the solution found. A value in the range of [4, 20] is
     * recommended.
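     * <br>
     * For example, with the default &alpha; = 4 and T = 1000 iterations, each
     * p<sub>ij</sub> is scaled by 4 for the first 250 iterations.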
     *
     * @param alpha the exaggeration constant
     */
    public void setAlpha(double alpha) {
        if (alpha <= 0 || Double.isNaN(alpha) || Double.isInfinite(alpha))
            throw new IllegalArgumentException("alpha must be positive, not " + alpha);
        this.alpha = alpha;
    }

    /**
     * 
     * @return the "early exaggeration" constant
     */
    public double getAlpha() {
        return alpha;
    }

    /**
     * Sets the target perplexity of the Gaussian used over each data point. The
     * perplexity can be thought of as a soft target for the number of nearest
     * neighbors to consider, and is adapted based on the distribution of the
     * data. Increasing the perplexity can increase the amount of time it takes
     * to get an embedding. Using a value in the range of [5, 50] is recommended.
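     * <br>
     * For example, with the default perplexity of 30, the sparse approximation
     * used by this class searches for floor(3 &times; 30) = 90 nearest
     * neighbors of each point.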
     *
     * @param perplexity the approximate number of neighbors to consider for
     *                   each data point
     */
    public void setPerplexity(double perplexity) {
        if (perplexity <= 0 || Double.isNaN(perplexity) || Double.isInfinite(perplexity))
            throw new IllegalArgumentException("perplexity must be positive, not " + perplexity);
        this.perplexity = perplexity;
    }

    /**
     * 
     * @return the target perplexity to use for each data point
     */
    public double getPerplexity() {
        return perplexity;
    }

    /**
     * Sets the desired number of gradient descent iterations to perform.
     * 
     * @param T the number of gradient descent iterations
     */
    public void setIterations(int T) {
        if (T <= 1)
            throw new IllegalArgumentException("number of iterations must be greater than 1, not " + T);
        this.T = T;
    }

    /**
     * 
     * @return the number of gradient descent iterations to perform
     */
    public int getIterations() {
        return T;
    }

    @Override
    public <Type extends DataSet> Type transform(DataSet<Type> d, boolean parallel) {
        Random rand = RandomUtil.getRandom();
        final int N = d.size();
        // If the perplexity is set too big, the search size would exceed the
        // dataset size, so cap it at N-1
        /**
         * from sec 4.1: "we compute the sparse approximation by finding the floor(3u)
         * nearest neighbors of each of the N input objects (recall that u is the
         * perplexity of the conditional distributions)"
         */
        final int knn = (int) Math.min(Math.floor(3 * perplexity), N - 1);

        /**
         * P_ij does not change at this point, so let's compute these values only
         * once. The j index matches up to the value stored in nearMe
         */
        final double[][] nearMePij = new double[N][knn];

        /**
         * Each row is the set of 3*u indices returned by the NN search
         */
        final int[][] nearMe = new int[N][knn];

        computeP(d, parallel, rand, knn, nearMe, nearMePij, dm, perplexity);

        Normal normalDist = new Normal(0, 1e-4);
        /**
         * For now, store all embedding coordinates in one flat array to avoid
         * excessive object overhead / cache misses
         */
        final double[] y = normalDist.sample(N * s, rand);

        final double[] y_grad = new double[y.length];

        // Vec wrapped versions for convenience
        final Vec y_vec = DenseVector.toDenseVec(y);
        final Vec y_grad_vec = DenseVector.toDenseVec(y_grad);

        GradientUpdater gradUpdater = new Adam();
        gradUpdater.setup(y.length);

        for (int iter = 0; iter < T; iter++)// optimization
        {
            final int ITER = iter;

            Arrays.fill(y_grad, 0);

            // First loop computes the F_rep forces. We do this first so the
            // normalization by Z can reuse one workspace for the gradient
            final Quadtree qt = new Quadtree(y);

            // TODO might not result in even load
            double Z = ParallelUtils.run(parallel, N, (start, end) -> {
                double[] workSpace = new double[s];
                double local_Z = 0;
                for (int i = start; i < end; i++) {
                    Arrays.fill(workSpace, 0.0);
                    local_Z += computeF_rep(qt.root, i, y, workSpace);

                    // should be multiplied by 4, rolling it into the normalization by Z after
                    for (int k = 0; k < s; k++)
                        inc_z_ij(workSpace[k], i, k, y_grad, s);
                }
                return local_Z;
            }, (a, b) -> a + b);

            // normalize by Z (and fold in the factor of 4 from the gradient)
            final double zNorm = 4.0 / (Z + 1e-13);
            for (int i = 0; i < y.length; i++)
                y_grad[i] *= zNorm;

            // The second loop computes the F_attr forces

            ParallelUtils.run(parallel, N, (start, end) -> {
                for (int i = start; i < end; i++)// N
                {
                    for (int j_indx = 0; j_indx < knn; j_indx++) // O(u)
                    {
                        int j = nearMe[i][j_indx];
                        if (i == j)// this should never happen b/c we skipped that when creating nearMe
                            continue;
                        double pij = nearMePij[i][j_indx];
                        if (ITER < T * exaggeratedPortion)
                            pij *= alpha;
                        double cnst = pij * q_ijZ(i, j, y, s) * 4;

                        for (int k = 0; k < s; k++) {
                            double diff = z_ij(i, k, y, s) - z_ij(j, k, y, s);
                            inc_z_ij(cnst * diff, i, k, y_grad, s);
                        }
                    }
                }
            });

            // now we have accumulated all gradients
            double eta = 200; // learning rate passed to the gradient updater

            gradUpdater.update(y_vec, y_grad_vec, eta);
        }

        DataSet<Type> transformed = d.shallowClone();

        final IdentityHashMap<DataPoint, Integer> indexMap = new IdentityHashMap<>(N);
        for (int i = 0; i < N; i++)
            indexMap.put(d.getDataPoint(i), i);

        transformed.applyTransform(new DataTransform() {

            @Override
            public DataPoint transform(DataPoint dp) {
                int i = indexMap.get(dp);
                DenseVector dv = new DenseVector(s);
                for (int k = 0; k < s; k++)
                    dv.set(k, y[i * s + k]);

                return new DataPoint(dv, dp.getCategoricalValues(), dp.getCategoricalData());
            }

            @Override
            public void fit(DataSet data) {
            }

            @Override
            public DataTransform clone() {
                return this;
            }
        });

        return (Type) transformed;
    }

    /**
     * 
     * @param d          the dataset to search
     * @param parallel   {@code true} if computation should be done with multiple
     *                   threads, {@code false} for single threaded
     * @param rand       source of randomness
     * @param knn        the number of neighbors to search for
     * @param nearMe     each row is the set of knn indices returned by the NN
     *                   search
     * @param nearMePij  the symmetrized neighbor probability
     * @param dm         the distance metric to use for determining closeness
     * @param perplexity the perplexity value for the effective nearest neighbor
     *                   search and weighting
     */
    protected static void computeP(DataSet d, boolean parallel, Random rand, final int knn, final int[][] nearMe, final double[][] nearMePij, final DistanceMetric dm, final double perplexity) {
        @SuppressWarnings("unchecked")
        final List<Vec> vecs = d.getDataVectors();
        final List<Double> accelCache = dm.getAccelerationCache(vecs, parallel);
        final int N = vecs.size();

        final VectorCollection<Vec> vc = new DefaultVectorCollection<>();
        vc.build(parallel, vecs, dm);

        final List<List<? extends VecPaired<Vec, Double>>> neighbors = new ArrayList<>(N);
        for (int i = 0; i < N; i++)
            neighbors.add(null);

        // new scope b/c I don't want to leak the silly vecIndex thing
        {
            // Used to map vecs back to their index so we can store only the ones we need in
            // nearMe
            final IdentityHashMap<Vec, Integer> vecIndex = new IdentityHashMap<>(N);
            for (int i = 0; i < N; i++)
                vecIndex.put(vecs.get(i), i);

            ParallelUtils.run(parallel, N, (i) -> {
                Vec x_i = vecs.get(i);
                List<? extends VecPaired<Vec, Double>> closest = vc.search(x_i, knn + 1);// +1 b/c self is closest
                neighbors.set(i, closest);
                for (int j = 1; j < closest.size(); j++) {
                    nearMe[i][j - 1] = vecIndex.get(closest.get(j).getVector());
                }
            });
        }
        // Now let's figure out each point's sigma
        final double[] sigma = new double[N];

        final AtomicDouble minSigma = new AtomicDouble(Double.POSITIVE_INFINITY);
        final AtomicDouble maxSigma = new AtomicDouble(0);

        for (int i = 0; i < N; i++)// first let's figure out a min/max range
        {
            List<? extends VecPaired<Vec, Double>> n_i = neighbors.get(i);
            double min = n_i.get(1).getPair();
            double max = n_i.get(Math.min(knn, n_i.size() - 1)).getPair();
            minSigma.set(Math.min(minSigma.get(), Math.max(min, 1e-9)));// avoid setting 0 as the min
            maxSigma.set(Math.max(maxSigma.get(), max));
        }

        // now compute the bandwidth for each datum
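        // i.e., for each i, solve perp(i, sigma_i) = perplexity for sigma_i by
        // root finding on f(sigma) = perp(i, sigma) - perplexity over the
        // bracket [minSigma, maxSigma]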
        ParallelUtils.run(parallel, N, (i) -> {
            boolean tryAgain = false;
            do {
                tryAgain = false;
                try {
                    double sigma_i = Zeroin.root(1e-2, 100, minSigma.get(), maxSigma.get(), (double x) -> perp(i, nearMe, x, neighbors, vecs, accelCache, dm) - perplexity);

                    sigma[i] = sigma_i;
                } catch (ArithmeticException exception)// perplexity not in the search range?
                {
                    if (maxSigma.get() >= Double.MAX_VALUE / 2) {
                        // Why can't we find a range that fits? Just pick a value:
                        // not MAX_VALUE, but something large enough to break the
                        // loop when the data is small
                        sigma[i] = 1e100;
                    } else {
                        tryAgain = true;
                        minSigma.set(Math.max(minSigma.get() / 2, 1e-6));
                        maxSigma.set(Math.min(maxSigma.get() * 2, Double.MAX_VALUE / 2));
                    }
                }
            } while (tryAgain);
        });

        ParallelUtils.run(parallel, N, (i) -> {
            for (int j_indx = 0; j_indx < knn; j_indx++) {
                int j = nearMe[i][j_indx];
                nearMePij[i][j_indx] = p_ij(i, j, sigma[i], sigma[j], neighbors, vecs, accelCache, dm);
            }
        });
    }

    /**
     * 
     * @param node      the node to begin computing from
     * @param i         the index of the embedded point the force is computed for
     * @param z         the flat array of embedded coordinates
     * @param workSpace the first 2 indices accumulate the contribution to the
     *                  gradient, sans multiplicative terms
     * @return the contribution to the normalizing constant Z
     */
    private double computeF_rep(Quadtree.Node node, int i, double[] z, double[] workSpace) {
        if (node == null || node.N_cell == 0 || node.indx == i)
            return 0;
        /*
         * The original paper says to use the diagonal divided by the squared 2
         * norm. This doesn't seem to work at all. Tried some different ideas with
         * 0.5 as the threshold until I found one that worked. Squaring the values
         * would normally not be helpful, but since we are working with tiny
         * values it makes them smaller, making it easier to hit the goal
         */
        double x = z[i * 2];
        double y = z[i * 2 + 1];
//        double r_cell = node.diagLen();
        double r_cell = Math.max(node.maxX - node.minX, node.maxY - node.minY);
        r_cell *= r_cell;
        double mass_x = node.x_mass / node.N_cell;
        double mass_y = node.y_mass / node.N_cell;
        double dot = (mass_x - x) * (mass_x - x) + (mass_y - y) * (mass_y - y);

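        // Barnes-Hut cutoff: if this is a leaf, or the (squared) cell width is
        // small relative to theta times the squared distance to the cell's
        // center of mass, summarize the whole cell as a single point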
        if (node.NW == null || r_cell < theta * dot)// good enough!
        {
            if (node.indx == i)
                return 0;

            double Z = 1.0 / (1.0 + dot);
            double q_cell_Z_sqrd = -node.N_cell * (Z * Z);

            workSpace[0] += q_cell_Z_sqrd * (x - mass_x);
            workSpace[1] += q_cell_Z_sqrd * (y - mass_y);
            return Z * node.N_cell;
        } else// further subdivide
        {
            double Z_sum = 0;
            for (Quadtree.Node child : node)
                Z_sum += computeF_rep(child, i, z, workSpace);
            return Z_sum;
        }
    }

    /**
     * 
     * @param val the value to add to the array
     * @param i   the index of the data point to add to
     * @param j   the dimension index of the embedding
     * @param z   the storage of the embedded vectors
     * @param s   the dimension of the embedding
     */
    private static void inc_z_ij(double val, int i, int j, double[] z, int s) {
        z[i * s + j] += val;
    }

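    /**
     * Reads the j'th embedding coordinate of the i'th point from the flat array
     * z of an s-dimensional embedding
     */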
    private static double z_ij(int i, int j, double[] z, int s) {
        return z[i * s + j];
    }

    /**
     * Computes the value of q<sub>ij</sub> Z = 1 / (1 + ||y<sub>i</sub> -
     * y<sub>j</sub>||<sup>2</sup>)
     * 
     * @param i the index of the first embedded point
     * @param j the index of the second embedded point
     * @param z the flat array of embedded coordinates
     * @param s the dimension of the embedding
     * @return the value of q<sub>ij</sub> Z
     */
    private static double q_ijZ(int i, int j, double[] z, int s) {
        double denom = 1;
        for (int k = 0; k < s; k++) {
            double diff = z_ij(i, k, z, s) - z_ij(j, k, z, s);
            denom += diff * diff;
        }

        return 1.0 / denom;
    }

    /**
     * Computes p<sub>j|i</sub> = exp(-d<sub>ij</sub><sup>2</sup> / (2
     * &sigma;<sub>i</sub><sup>2</sup>)) / &Sigma;<sub>k</sub>
     * exp(-d<sub>ik</sub><sup>2</sup> / (2 &sigma;<sub>i</sub><sup>2</sup>)),
     * where the sum runs over the stored nearest neighbors of i
     * 
     * @param j          the index of the point the probability is computed for
     * @param i          the index of the point being conditioned on
     * @param sigma      the bandwidth &sigma;<sub>i</sub> to use
     * @param neighbors  the nearest neighbor list for every point
     * @param vecs       the original data vectors
     * @param accelCache the acceleration cache for the distance metric
     * @param dm         the distance metric to use
     * @return p<sub>j|i</sub>
     */
    private static double p_j_i(int j, int i, double sigma, List<List<? extends VecPaired<Vec, Double>>> neighbors, List<Vec> vecs, List<Double> accelCache, DistanceMetric dm) {
        /*
         * "Because we are only interested in modeling pairwise similarities, we set the
         * value of pi|i to zero" from Visualizing Data using t-SNE
         */
        if (i == j)
            return 0;
        // nearest neighbor is self; use that to get the indexed vector
        Vec x_j = neighbors.get(j).get(0).getVector();
//        Vec x_i = neighbors.get(i).get(0).getVector();

        final double sigmaSqrdInv = 1 / (2 * (sigma * sigma));

        double numer = 0;
        double denom = 0;
        boolean jIsNearBy = false;
        final List<? extends VecPaired<Vec, Double>> neighbors_i = neighbors.get(i);
        for (int k = 1; k < neighbors_i.size(); k++)// SUM over k != i
        {
            VecPaired<Vec, Double> neighbor_ik = neighbors_i.get(k);
            final double d_ik = neighbor_ik.getPair();
            denom += FastMath.exp(-(d_ik * d_ik) * sigmaSqrdInv);

            if (neighbor_ik.getVector() == x_j)// intentionally doing object equals check - should be same object
            {
                jIsNearBy = true;// yay, we don't have to compute the distance ourselves
                numer = FastMath.exp(-(d_ik * d_ik) * sigmaSqrdInv);
            }
        }

        if (!jIsNearBy) {
            double d_ij = dm.dist(i, j, vecs, accelCache);
            numer = FastMath.exp(-(d_ij * d_ij) * sigmaSqrdInv);
        }

        return numer / (denom + 1e-9);
    }

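    /**
     * Computes the symmetrized affinity p<sub>ij</sub> = (p<sub>j|i</sub> +
     * p<sub>i|j</sub>) / (2 N), where N is the number of data points
     */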
    private static double p_ij(int i, int j, double sigma_i, double sigma_j, List<List<? extends VecPaired<Vec, Double>>> neighbors, List<Vec> vecs, List<Double> accelCache, DistanceMetric dm) {
        return (p_j_i(j, i, sigma_i, neighbors, vecs, accelCache, dm) + p_j_i(i, j, sigma_j, neighbors, vecs, accelCache, dm)) / (2 * neighbors.size());
    }

    /**
     * Computes the perplexity for the specified data point using the given
     * sigma, where the perplexity is 2<sup>H(P<sub>i</sub>)</sup> and
     * H(P<sub>i</sub>) is the Shannon entropy (base 2) of the conditional
     * distribution P<sub>i</sub>
     * 
     * @param i          the data point to get the perplexity of
     * @param nearMe     each row is the set of knn indices returned by the NN
     *                   search
     * @param sigma      the bandwidth to use
     * @param neighbors  the set of nearest neighbors to consider
     * @param vecs       the original data vectors
     * @param accelCache the acceleration cache for the distance metric
     * @param dm         the distance metric to use
     * @return the perplexity 2<sup>H(P<sub>i</sub>)</sup>
     */
    private static double perp(int i, int[][] nearMe, double sigma, List<List<? extends VecPaired<Vec, Double>>> neighbors, List<Vec> vecs, List<Double> accelCache, DistanceMetric dm) {
        // section 2 of Maaten, L. Van Der, & Hinton, G. (2008). Visualizing Data using
        // t-SNE. Journal of Machine Learning Research, 9, 2579–2605.
        double hp = 0;

        for (int j_indx = 0; j_indx < nearMe[i].length; j_indx++) {
            double p_ji = p_j_i(nearMe[i][j_indx], i, sigma, neighbors, vecs, accelCache, dm);

            if (p_ji > 0)
                hp += p_ji * FastMath.log2(p_ji);
        }
        hp *= -1;

        return FastMath.pow2(hp);
    }

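    /**
     * A simple 2D quadtree over the embedding coordinates, used for the
     * Barnes-Hut style approximation of the repulsive forces. Each node tracks
     * the number of points and the summed coordinates in its cell, so a
     * sufficiently distant cell can be summarized as a single point at its
     * center of mass.
     */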
    private class Quadtree {
        public Node root;

        public Quadtree(double[] z) {
            this.root = new Node();
            this.root.minX = this.root.minY = Double.POSITIVE_INFINITY;
            this.root.maxX = this.root.maxY = -Double.POSITIVE_INFINITY;

            for (int i = 0; i < z.length / 2; i++) {
                double x = z[i * 2];
                double y = z[i * 2 + 1];
                this.root.minX = Math.min(this.root.minX, x);
                this.root.maxX = Math.max(this.root.maxX, x);
                this.root.minY = Math.min(this.root.minY, y);
                this.root.maxY = Math.max(this.root.maxY, y);
            }

            // done b/c contains() is half-open (min inclusive, max exclusive),
            // so nudge the max up slightly so edge points are included
            this.root.maxX = Math.nextUp(this.root.maxX);
            this.root.maxY = Math.nextUp(this.root.maxY);

            // now start inserting everything
            for (int i = 0; i < z.length / 2; i++)
                root.insert(1, i, z);
        }

        private class Node implements Iterable<Node> {
            public int indx;
            public double x_mass, y_mass;
            public int N_cell;
            public double minX, maxX, minY, maxY;
            public Node NW, NE, SE, SW;

            public Node() {
                indx = -1;
                N_cell = 0;
                x_mass = y_mass = 0;
                NW = NE = SE = SW = null;
            }

            public Node(double minX, double maxX, double minY, double maxY) {
                this();
                this.minX = minX;
                this.maxX = maxX;
                this.minY = minY;
                this.maxY = maxY;
            }

            public boolean contains(int i, double[] z) {
                double x = z[i * 2];
                double y = z[i * 2 + 1];

                return minX <= x && x < maxX && minY <= y && y < maxY;
            }

            public void insert(int weight, int i, double[] z) {
                // accumulate weighted mass so the center of mass stays correct
                // when a point carries the weight of collapsed duplicates
                x_mass += weight * z[i * 2];
                y_mass += weight * z[i * 2 + 1];
                N_cell += weight;
                if (NW == null && indx < 0)// was empty, just set
                    indx = i;
                else {
                    if (indx >= 0) {
                        if (Math.abs(z[indx * 2] - z[i * 2]) < 1e-13 && Math.abs(z[indx * 2 + 1] - z[i * 2 + 1]) < 1e-13) {
                            // near exact same value; let the increased local
                            // weight indicate a "heavier" leaf
                            return;
                        }
                    }
                    if (NW == null)// we need to split
                    {
                        double w2 = (maxX - minX) / 2;
                        double h2 = (maxY - minY) / 2;

                        NW = new Node(minX, minX + w2, minY + h2, maxY);
                        NE = new Node(minX + w2, maxX, minY + h2, maxY);
                        SW = new Node(minX, minX + w2, minY, minY + h2);
                        SE = new Node(minX + w2, maxX, minY, minY + h2);

                        // move the old point down with the weight it had before
                        // this insert (N_cell already includes the new point's
                        // weight)
                        for (Node child : this)
                            if (child.contains(this.indx, z)) {
                                child.insert(this.N_cell - weight, this.indx, z);
                                break;
                            }
                        indx = -1;
                    }
                    // and pass this along to our children
                    for (Node child : this)
                        if (child.contains(i, z)) {
                            child.insert(weight, i, z);
                            break;
                        }

                }
            }

            public double diagLen() {
                double w = maxX - minX;
                double h = maxY - minY;
                return Math.sqrt(w * w + h * h);
            }

            @Override
            public Iterator<Node> iterator() {
                if (NW == null)
                    return Collections.emptyIterator();
                else
                    return Arrays.asList(NW, NE, SW, SE).iterator();
            }

        }
    }

    // Current implementation only supports 2D, so hard code it.

    @Override
    public int getTargetDimension() {
        return 2;
    }

    @Override
    public boolean setTargetDimension(int target) {
        return target == 2;
    }
}
